import argparse
import json
import logging
import math
import re
import time
import typing
from pathlib import Path

import flax
import numpy as np
import orbax
import safetensors.torch
import torch
from recurrentgemma import jax as recurrentgemma_jax
from transformers import AutoConfig, AutoModelForCausalLM

import tensorrt_llm
from tensorrt_llm import logger
from tensorrt_llm._utils import (numpy_to_torch, str_dtype_to_torch,
                                 torch_to_numpy)

LOGGER = logging.getLogger("convert_checkpoint")


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt_type", type=str, choices=["jax", "hf"])
    parser.add_argument("--model_dir", type=Path, default=None)
    parser.add_argument("--world_size",
                        type=int,
                        default=1,
                        help="world size, only support tensor parallelism now")
    parser.add_argument("--dtype",
                        type=str,
                        default="float16",
                        choices=["float32", "bfloat16", "float16"])
    parser.add_argument(
        "--output_dir",
        type=Path,
        default="recurrentgemma_tllm_checkpoint",
        help="The path to save the recurrentgemma TensorRT-LLM checkpoint")
    parser.add_argument("--log_level", type=str, default="info")
    args = parser.parse_args()
    return args


class JAXParser:

    def load_parameters(self, checkpoint_path: Path):
        checkpoint_path = checkpoint_path.absolute()
        checkpointer = orbax.checkpoint.PyTreeCheckpointer()
        params = checkpointer.restore(checkpoint_path)
        return params

    def embedding_weights(self, ckpt_params):
        return ckpt_params["embedder"]["input_embedding"]

    def get_config(self, checkpoint_path, ckpt_params):
        config = recurrentgemma_jax.GriffinConfig.from_flax_params_or_variables(
            ckpt_params)._asdict()
        if config["lru_width"] is None:
            config["lru_width"] = config["width"]
        layer_types = []
        for p in config["block_types"]:
            if p == recurrentgemma_jax.TemporalBlockType.ATTENTION:
                layer_types.append("attention")
            else:
                layer_types.append("recurrent")
        config["block_types"] = layer_types
        config["hidden_size"] = config.pop("width")
        config["num_attention_heads"] = config.pop("num_heads")
        config["intermediate_size"] = config.pop("mlp_expanded_width")
        config["num_hidden_layers"] = len(config["block_types"])
        return config

    def rename_to_trt_llm(self, name: str):
        """Rename a recurrentgemma parameter name by the corresponding TRT-LLM style name."""
        sub_patterns = (
            (r"embedder.input_embedding", r"vocab_embedding.weight"),
            (r"blocks.(\d+).channel_pre_norm.scale",
             r"layers.\1.post_layernorm.weight"),
            (r"blocks.(\d+).temporal_pre_norm.scale",
             r"layers.\1.input_layernorm.weight"),
            (r"blocks.(\d+).recurrent_block.conv_1d.w",
             r"layers.\1.recurrent.conv1d.weight"),
            (r"blocks.(\d+).recurrent_block.conv_1d.b",
             r"layers.\1.recurrent.conv1d.bias"),
            (r"blocks.(\d+).recurrent_block.linear_out.kernel",
             r"layers.\1.recurrent.linear_out.weight"),
            (r"blocks.(\d+).recurrent_block.linear_out.bias",
             r"layers.\1.recurrent.linear_out.bias"),
            (r"blocks.(\d+).recurrent_block.linear_x.kernel",
             r"layers.\1.recurrent.linear_x.weight"),
            (r"blocks.(\d+).recurrent_block.linear_x.bias",
             r"layers.\1.recurrent.linear_x.bias"),
            (r"blocks.(\d+).recurrent_block.linear_y.kernel",
             r"layers.\1.recurrent.linear_y.weight"),
            (r"blocks.(\d+).recurrent_block.linear_y.bias",
             r"layers.\1.recurrent.y_bias"),
            (r"blocks.(\d+).recurrent_block.rg_lru.a_gate.w",
             r"layers.\1.recurrent.rg_lru.recurrent_gate.weight"),
            (r"blocks.(\d+).recurrent_block.rg_lru.a_gate.b",
             r"layers.\1.recurrent.rg_lru.recurrent_gate.bias"),
            (r"blocks.(\d+).recurrent_block.rg_lru.input_gate.w",
             r"layers.\1.recurrent.rg_lru.input_gate.weight"),
            (r"blocks.(\d+).recurrent_block.rg_lru.input_gate.b",
             r"layers.\1.recurrent.rg_lru.input_gate.bias"),
            (r"blocks.(\d+).recurrent_block.rg_lru.a_param",
             r"layers.\1.recurrent.rg_lru.recurrent_param"),
            (r"blocks.(\d+).mlp_block.ffw_up.w", r"layers.\1.mlp.fc.weight"),
            (r"blocks.(\d+).mlp_block.ffw_up.b", None),
            (r"blocks.(\d+).mlp_block.ffw_down.kernel",
             r"layers.\1.mlp.proj.weight"),
            (r"blocks.(\d+).mlp_block.ffw_down.bias",
             r"layers.\1.mlp.proj.bias"),
            (r"blocks.(\d+).attention_block.proj_q.kernel",
             r"layers.\1.attention.qkv.weight"),
            (r"blocks.(\d+).attention_block.proj_k.kernel", None),
            (r"blocks.(\d+).attention_block.proj_v.kernel", None),
            (r"blocks.(\d+).attention_block.proj_final.kernel",
             r"layers.\1.attention.dense.weight"),
            (r"blocks.(\d+).attention_block.proj_final.bias",
             r"layers.\1.attention.dense.bias"),
            (r"final_norm.scale", r"ln_f.weight"),
        )

        for source, target in sub_patterns:
            if re.match(source, name):
                if target is None:
                    return target
                else:
                    name = re.sub(source, target, name)
                    return ".".join(("transformer", name))
        else:
            raise ValueError(f"Don't know how to rename {name}")

    def flatten_params(self, params):
        new_params = flax.traverse_util.flatten_dict(params, sep=".")
        # if the dtype is bfloat16, cast to float32
        for k in new_params:
            if new_params[k].dtype != np.float32 and new_params[
                    k].dtype != np.float16:
                new_params[k] = new_params[k].astype(np.float32)
        return new_params


class HfParser:

    def load_parameters(self, checkpoint_path: Path):
        hf_model = AutoModelForCausalLM.from_pretrained(
            checkpoint_path,
            device_map="auto",
            torch_dtype="auto",
        )
        model_params = dict(hf_model.named_parameters())
        return model_params

    def embedding_weights(self, ckpt_params):
        return ckpt_params["model.embed_tokens.weight"]

    def get_config(self, checkpoint_path, ckpt_params):
        checkpoint_path = checkpoint_path.absolute()
        hf_config = AutoConfig.from_pretrained(
            checkpoint_path, trust_remote_code=True).to_dict()
        hf_config["block_types"] = hf_config.pop("_block_types")
        hf_config["intermediate_size"] = hf_config["intermediate_size"] // 2
        return hf_config

    def rename_to_trt_llm(self, name: str):
        """Rename a recurrentgemma parameter name by the corresponding TRT-LLM style name."""
        sub_patterns = (
            (r"model.embed_tokens.weight", r"vocab_embedding.weight"),
            (r"model.layers.(\d+).temporal_pre_norm.weight",
             r"layers.\1.input_layernorm.weight"),
            (r"model.layers.(\d+).channel_pre_norm.weight",
             r"layers.\1.post_layernorm.weight"),
            (r"model.layers.(\d+).temporal_block.conv_1d.weight",
             r"layers.\1.recurrent.conv1d.weight"),
            (r"model.layers.(\d+).temporal_block.conv_1d.bias",
             r"layers.\1.recurrent.conv1d.bias"),
            (r"model.layers.(\d+).temporal_block.linear_out.weight",
             r"layers.\1.recurrent.linear_out.weight"),
            (r"model.layers.(\d+).temporal_block.linear_out.bias",
             r"layers.\1.recurrent.linear_out.bias"),
            (r"model.layers.(\d+).temporal_block.linear_x.weight",
             r"layers.\1.recurrent.linear_x.weight"),
            (r"model.layers.(\d+).temporal_block.linear_x.bias",
             r"layers.\1.recurrent.linear_x.bias"),
            (r"model.layers.(\d+).temporal_block.linear_y.weight",
             r"layers.\1.recurrent.linear_y.weight"),
            (r"model.layers.(\d+).temporal_block.linear_y.bias",
             r"layers.\1.recurrent.y_bias"),
            (r"model.layers.(\d+).temporal_block.rg_lru.recurrent_gate_weight",
             r"layers.\1.recurrent.rg_lru.recurrent_gate.weight"),
            (r"model.layers.(\d+).temporal_block.rg_lru.recurrent_gate_bias",
             r"layers.\1.recurrent.rg_lru.recurrent_gate.bias"),
            (r"model.layers.(\d+).temporal_block.rg_lru.input_gate_weight",
             r"layers.\1.recurrent.rg_lru.input_gate.weight"),
            (r"model.layers.(\d+).temporal_block.rg_lru.input_gate_bias",
             r"layers.\1.recurrent.rg_lru.input_gate.bias"),
            (r"model.layers.(\d+).temporal_block.rg_lru.recurrent_param",
             r"layers.\1.recurrent.rg_lru.recurrent_param"),
            (r"model.layers.(\d+).mlp_block.up_proj.weight",
             r"layers.\1.mlp.gate.weight"),
            (r"model.layers.(\d+).mlp_block.up_proj.bias",
             r"layers.\1.mlp.gate.bias"),
            (r"model.layers.(\d+).mlp_block.gate_proj.weight",
             r"layers.\1.mlp.fc.weight"),
            (r"model.layers.(\d+).mlp_block.gate_proj.bias",
             r"layers.\1.mlp.fc.bias"),
            (r"model.layers.(\d+).mlp_block.down_proj.weight",
             r"layers.\1.mlp.proj.weight"),
            (r"model.layers.(\d+).mlp_block.down_proj.bias",
             r"layers.\1.mlp.proj.bias"),
            (r"model.layers.(\d+).temporal_block.q_proj.weight",
             r"layers.\1.attention.qkv.weight"),
            (r"model.layers.(\d+).temporal_block.k_proj.weight", None),
            (r"model.layers.(\d+).temporal_block.v_proj.weight", None),
            (r"model.layers.(\d+).temporal_block.o_proj.weight",
             r"layers.\1.attention.dense.weight"),
            (r"model.layers.(\d+).temporal_block.o_proj.bias",
             r"layers.\1.attention.dense.bias"),
            (r"model.final_norm.weight", r"ln_f.weight"),
        )

        for source, target in sub_patterns:
            if re.match(source, name):
                if target is None:
                    return target
                else:
                    name = re.sub(source, target, name)
                    return ".".join(("transformer", name))
        else:
            raise ValueError(f"Don't know how to rename {name}")

    def flatten_params(self, params):
        f_params = {}
        for k, v in params.items():
            if v.dtype == torch.bfloat16:
                v = v.float()
            f_params[k] = torch_to_numpy(v)
        return f_params


CKPT_PARSER = {
    "jax": JAXParser,
    "hf": HfParser,
}


def split(v, tp_size, idx, dim=0):
    if tp_size == 1:
        return v
    return np.split(v, tp_size, axis=dim)[idx]


def split_matrix_tp(v, tensor_parallel, rank, dim):
    return split(v, tensor_parallel, rank, dim=dim)


def add_trt_llm_weight(weights: typing.Dict[str, np.ndarray],
                       name: str,
                       param: np.ndarray,
                       dtype: typing.Optional[np.dtype] = None):
    assert name not in weights, f"{name} is already added."
    param = numpy_to_torch(param)
    if dtype is not None:
        assert isinstance(dtype,
                          str), f"dtype must be str, but get type {type(dtype)}"
        param = param.to(str_dtype_to_torch(dtype))
    weights[name] = param.contiguous()


def convert_from_checkpoint(
    trt_llm_config: tensorrt_llm.models.modeling_utils.PretrainedConfig,
    model_dir: typing.Union[str, Path],
    ckpt_parser,
    rank=0,
):
    print("Loading weights...")
    tik = time.time()

    tp_rank = rank
    tp_size = trt_llm_config.mapping.tp_size
    intermediate_size = trt_llm_config.intermediate_size
    rnn_hidden_size = trt_llm_config.rnn_hidden_size
    conv_kernel = trt_llm_config.conv_kernel

    weights = {}
    for model_file in [model_dir]:
        LOGGER.debug(f"Loading directory {str(model_file)}...")
        model_params = ckpt_parser.load_parameters(model_file)
        model_params = ckpt_parser.flatten_params(model_params)

        for name, param in model_params.items():
            LOGGER.debug(f"Converting weight {name}...")
            trt_llm_name = ckpt_parser.rename_to_trt_llm(name)
            if trt_llm_name is None:  # omit as used with other params
                continue

            if "proj_q" in name or "q_proj" in name:
                if isinstance(ckpt_parser, JAXParser):
                    k_name = name.replace("proj_q", "proj_k")
                    v_name = name.replace("proj_q", "proj_v")
                    q_param = param.transpose(1, 0)
                    k_param = model_params[k_name].transpose(1, 0)
                    v_param = model_params[v_name].transpose(1, 0)
                else:
                    k_name = name.replace("q_proj", "k_proj")
                    v_name = name.replace("q_proj", "v_proj")
                    q_param = param
                    k_param = model_params[k_name]
                    v_param = model_params[v_name]
                q_param = split_matrix_tp(q_param, tp_size, tp_rank, dim=0)
                qkv_param = np.concatenate([q_param, k_param, v_param], axis=0)
                add_trt_llm_weight(weights, trt_llm_name, qkv_param,
                                   trt_llm_config.dtype)
            elif "ffw_up.w" in name and isinstance(ckpt_parser, JAXParser):
                bias_name = name.replace("ffw_up.w", "ffw_up.b")
                fc_param, gate_param = param[0, ].transpose(
                    1, 0), param[1, ].transpose(1, 0)
                fc_param = split_matrix_tp(fc_param, tp_size, tp_rank, dim=0)
                gate_param = split_matrix_tp(gate_param,
                                             tp_size,
                                             tp_rank,
                                             dim=0)
                fc_bias = model_params[bias_name][0, ].reshape(
                    intermediate_size)
                gate_bias = model_params[bias_name][1, ].reshape(
                    intermediate_size)
                fc_bias = split_matrix_tp(fc_bias, tp_size, tp_rank, dim=0)
                gate_bias = split_matrix_tp(gate_bias, tp_size, tp_rank, dim=0)
                trt_llm_fc_name = trt_llm_name
                trt_llm_gate_name = trt_llm_name.replace(
                    "fc.weight", "gate.weight")
                trt_llm_fc_b_name = trt_llm_name.replace("fc.weight", "fc.bias")
                trt_llm_gate_b_name = trt_llm_name.replace(
                    "fc.weight", "gate.bias")
                add_trt_llm_weight(weights, trt_llm_fc_name, fc_param,
                                   trt_llm_config.dtype)
                add_trt_llm_weight(weights, trt_llm_gate_name, gate_param,
                                   trt_llm_config.dtype)
                add_trt_llm_weight(weights, trt_llm_fc_b_name, fc_bias,
                                   trt_llm_config.dtype)
                add_trt_llm_weight(weights, trt_llm_gate_b_name, gate_bias,
                                   trt_llm_config.dtype)
            elif "conv_1d.w" in name:
                if isinstance(ckpt_parser, JAXParser):
                    param = param.transpose(1,
                                            0).reshape(rnn_hidden_size, 1,
                                                       conv_kernel, 1)
                else:
                    param = param.reshape(rnn_hidden_size, 1, conv_kernel, 1)
                param = split_matrix_tp(param, tp_size, tp_rank, dim=0)
                add_trt_llm_weight(weights, trt_llm_name, param,
                                   trt_llm_config.dtype)
            elif "embedder.input_embedding" in name or "model.embed_tokens" in name:
                if not trt_llm_config.share_embedding_table:
                    lm_head = split_matrix_tp(param, tp_size, tp_rank, dim=0)
                    add_trt_llm_weight(weights, "lm_head.weight",
                                       np.copy(lm_head), trt_llm_config.dtype)
                if trt_llm_config.emb_scale_by_sqrt_dim:
                    param = np.multiply(
                        param.astype(np.float32),
                        math.sqrt(trt_llm_config.hidden_size),
                    )
                if trt_llm_config.use_parallel_embedding:
                    assert trt_llm_config.vocab_size % tp_size == 0
                    param = split_matrix_tp(
                        param,
                        tp_size,
                        tp_rank,
                        dim=trt_llm_config.embedding_sharding_dim,
                    )
                add_trt_llm_weight(weights, trt_llm_name, param,
                                   trt_llm_config.dtype)
            elif any(keyword in name for keyword in (
                    "proj_final.kernel",
                    "ffw_down.kernel",
                    "linear_out.kernel",
                    "o_proj.weight",
                    "down_proj.weight",
                    "linear_out.weight",
            )):
                if isinstance(ckpt_parser, JAXParser):
                    param = param.transpose(1, 0)
                param = split_matrix_tp(param, tp_size, tp_rank, dim=1)
                add_trt_llm_weight(weights, trt_llm_name, param,
                                   trt_llm_config.dtype)
            elif any(keyword in name for keyword in (
                    "linear_x.kernel",
                    "linear_y.kernel",
                    "linear_x.weight",
                    "linear_y.weight",
                    "up_proj.weight",
                    "gate_proj.weight",
            )):
                if isinstance(ckpt_parser, JAXParser):
                    param = param.transpose(1, 0)
                param = split_matrix_tp(param, tp_size, tp_rank, dim=0)
                add_trt_llm_weight(weights, trt_llm_name, param,
                                   trt_llm_config.dtype)
            elif any(keyword in name for keyword in (
                    "linear_x.bias",
                    "linear_y.bias",
                    "rg_lru",
                    "conv_1d.b",
                    "gate_proj.bias",
                    "up_proj.bias",
            )):
                param = split_matrix_tp(param, tp_size, tp_rank, dim=0)
                add_trt_llm_weight(weights, trt_llm_name, param,
                                   trt_llm_config.dtype)
            elif any(keyword in name for keyword in (
                    "channel_pre_norm",
                    "temporal_pre_norm",
                    "final_norm",
            )):
                param = param + 1.0
                add_trt_llm_weight(weights, trt_llm_name, param,
                                   trt_llm_config.dtype)
            elif any(keyword in name for keyword in (
                    "proj_final.bias",
                    "ffw_down.bias",
                    "linear_out.bias",
                    "o_proj.bias",
                    "down_proj.bias",
            )):
                add_trt_llm_weight(weights, trt_llm_name, param,
                                   trt_llm_config.dtype)
            else:
                raise RuntimeError(f"Unhandled {name} module weights")
        del model_params
    print(
        f"Weights loaded. Total time: {time.strftime('%H:%M:%S', time.gmtime(time.time() - tik))}"
    )
    return weights


def convert(worker_rank, args, convert_kwargs):
    for rank in range(worker_rank, args.world_size):
        weights = convert_from_checkpoint(rank=rank, **convert_kwargs)
        safetensors.torch.save_file(weights,
                                    args.output_dir / f"rank{rank}.safetensors")


def main():
    print(tensorrt_llm.__version__)

    args = parse_arguments()
    logger.set_level(args.log_level)
    tik = time.time()

    print(f"Loading source parameters from {args.model_dir.absolute()}")
    ckpt_parser = CKPT_PARSER[args.ckpt_type]()
    ckpt_params = ckpt_parser.load_parameters(args.model_dir)
    input_embedding_weights = ckpt_parser.embedding_weights(ckpt_params)
    ckpt_params_dtype = str(input_embedding_weights.dtype).split(".")[-1]
    ckpt_config = ckpt_parser.get_config(args.model_dir, ckpt_params)

    print(f"Source configuration determined from parameters: {ckpt_config}")

    quant_config = tensorrt_llm.models.modeling_utils.QuantConfig()
    trt_llm_config = tensorrt_llm.models.modeling_utils.PretrainedConfig(
        architecture="RecurrentGemmaForCausalLM",
        dtype=args.dtype or ckpt_params_dtype,
        logits_dtype="float32",
        vocab_size=ckpt_config["vocab_size"],
        # follow the setting of gemma models
        max_position_embeddings=8192,
        hidden_size=ckpt_config["hidden_size"],
        num_hidden_layers=ckpt_config["num_hidden_layers"],
        num_attention_heads=ckpt_config["num_attention_heads"],
        num_key_value_heads=1,
        head_size=ckpt_config["hidden_size"] //
        ckpt_config["num_attention_heads"],
        hidden_act="gelu",
        intermediate_size=ckpt_config["intermediate_size"],
        norm_epsilon=1e-6,
        position_embedding_type="rope_gpt_neox",
        mapping={
            'world_size': args.world_size,
            'tp_size': args.world_size,
            'pp_size': 1
        },
        gpus_per_node=8,
        quantization=quant_config,
        conv_kernel=4,
        state_size=1,
        state_dtype='float32',
        rotary_pct=0.5,
        layer_types=ckpt_config["block_types"],
        rnn_hidden_size=ckpt_config["lru_width"],
        logits_soft_cap=ckpt_config["logits_soft_cap"],
        emb_scale_by_sqrt_dim=ckpt_config["embeddings_scale_by_sqrt_dim"],
    )

    trt_llm_config_dict = trt_llm_config.to_dict()
    print(f"Determined TensorRT-LLM configuration {trt_llm_config_dict}")

    config_path = args.output_dir / "config.json"
    config_path.parent.mkdir(exist_ok=True, parents=True)
    LOGGER.debug(f"Saving TensorRT-LLM configuration to {config_path}")
    with config_path.open("w") as config_file:
        json.dump(trt_llm_config_dict, config_file, indent=4)

    convert_args = dict(trt_llm_config=trt_llm_config,
                        model_dir=args.model_dir,
                        ckpt_parser=ckpt_parser)
    convert(0, args, convert_args)

    elapsed = time.strftime("%H:%M:%S", time.gmtime(time.time() - tik))
    print(f"Total time of converting checkpoints: {elapsed}")


if __name__ == "__main__":
    main()
